{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import time\n",
    "import cv2\n",
    "import math\n",
    "import copy\n",
    "import onnxruntime\n",
    "import numpy as np\n",
    "import pyclipper\n",
    "from shapely.geometry import Polygon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class NormalizeImage(object):\n",
    "    \"\"\"Scale an image, subtract the per-channel mean and divide by the std.\n",
    "\n",
    "    Args:\n",
    "        scale: factor the pixels are multiplied by first. A string is\n",
    "            evaluated as a Python expression (e.g. '1./255.'); None means\n",
    "            1/255.\n",
    "        mean: per-channel mean; None means [0.485, 0.456, 0.406].\n",
    "        std: per-channel std; None means [0.229, 0.224, 0.225].\n",
    "        order: 'chw' reshapes mean/std to (3, 1, 1); any other value\n",
    "            reshapes them to (1, 1, 3).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs):\n",
    "        if isinstance(scale, str):\n",
    "            # NOTE(review): eval() runs arbitrary code; only pass trusted\n",
    "            # configuration strings as `scale`.\n",
    "            scale = eval(scale)\n",
    "        self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)\n",
    "        mean = mean if mean is not None else [0.485, 0.456, 0.406]\n",
    "        std = std if std is not None else [0.229, 0.224, 0.225]\n",
    "\n",
    "        shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)\n",
    "        self.mean = np.array(mean).reshape(shape).astype('float32')\n",
    "        self.std = np.array(std).reshape(shape).astype('float32')\n",
    "\n",
    "    def __call__(self, data):\n",
    "        \"\"\"Replace data['image'] with its normalized float32 version.\n",
    "\n",
    "        Returns:\n",
    "            The same dict that was passed in.\n",
    "\n",
    "        Raises:\n",
    "            AssertionError: if the image is neither a PIL image nor an ndarray.\n",
    "        \"\"\"\n",
    "        img = data['image']\n",
    "        # Pillow is optional: ndarray inputs must work without it installed.\n",
    "        try:\n",
    "            from PIL import Image\n",
    "        except ImportError:\n",
    "            Image = None\n",
    "        if Image is not None and isinstance(img, Image.Image):\n",
    "            img = np.array(img)\n",
    "\n",
    "        assert isinstance(img,\n",
    "                          np.ndarray), \"invalid input 'img' in NormalizeImage\"\n",
    "        data['image'] = (\n",
    "            img.astype('float32') * self.scale - self.mean) / self.std\n",
    "        return data\n",
    "\n",
    "\n",
    "class ToCHWImage(object):\n",
    "    \"\"\"Convert an HWC image stored under data['image'] to CHW layout.\"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        pass\n",
    "\n",
    "    def __call__(self, data):\n",
    "        \"\"\"Transpose data['image'] with axes (2, 0, 1) and return the dict.\"\"\"\n",
    "        img = data['image']\n",
    "        # Pillow is optional: ndarray inputs must work without it installed.\n",
    "        try:\n",
    "            from PIL import Image\n",
    "        except ImportError:\n",
    "            Image = None\n",
    "        if Image is not None and isinstance(img, Image.Image):\n",
    "            img = np.array(img)\n",
    "        data['image'] = img.transpose((2, 0, 1))\n",
    "        return data\n",
    "\n",
    "\n",
    "class KeepKeys(object):\n",
    "    \"\"\"Turn a data dict into a list holding only the configured keys.\"\"\"\n",
    "\n",
    "    def __init__(self, keep_keys, **kwargs):\n",
    "        self.keep_keys = keep_keys\n",
    "\n",
    "    def __call__(self, data):\n",
    "        \"\"\"Return the values of the configured keys, in configured order.\"\"\"\n",
    "        return [data[name] for name in self.keep_keys]\n",
    "\n",
    "class DetResizeForTest(object):\n",
    "    \"\"\"Resize an image for the detection model and record the scale factors.\n",
    "\n",
    "    The resize strategy is picked from the keyword arguments:\n",
    "        image_shape [, keep_ratio]: resize_type 1, resize to a given (h, w).\n",
    "        limit_side_len [, limit_type]: resize_type 0, bound one side and\n",
    "            round both sides to a multiple of 32.\n",
    "        resize_long: resize_type 2, scale the long side to this length and\n",
    "            zero-pad the result up to at least 640x640.\n",
    "        no argument: resize_type 0 with limit_side_len=736, limit_type='min'.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, **kwargs):\n",
    "        super(DetResizeForTest, self).__init__()\n",
    "        self.resize_type = 0\n",
    "        self.keep_ratio = False\n",
    "        if 'image_shape' in kwargs:\n",
    "            self.image_shape = kwargs['image_shape']\n",
    "            self.resize_type = 1\n",
    "            if 'keep_ratio' in kwargs:\n",
    "                self.keep_ratio = kwargs['keep_ratio']\n",
    "        elif 'limit_side_len' in kwargs:\n",
    "            self.limit_side_len = kwargs['limit_side_len']\n",
    "            self.limit_type = kwargs.get('limit_type', 'min')\n",
    "        elif 'resize_long' in kwargs:\n",
    "            self.resize_type = 2\n",
    "            self.resize_long = kwargs.get('resize_long', 640)\n",
    "        else:\n",
    "            self.limit_side_len = 736\n",
    "            self.limit_type = 'min'\n",
    "\n",
    "    def __call__(self, data):\n",
    "        \"\"\"Resize data['image'] and store the geometry in data['shape'].\n",
    "\n",
    "        data['shape'] becomes [src_h, src_w, ratio_h, ratio_w]; the source\n",
    "        size is read before any padding or resizing happens.\n",
    "        \"\"\"\n",
    "        img = data['image']\n",
    "        src_h, src_w, _ = img.shape\n",
    "        # Very small inputs are zero-padded to at least 32x32 first.\n",
    "        if sum([src_h, src_w]) < 64:\n",
    "            img = self.image_padding(img)\n",
    "        if self.resize_type == 0:\n",
    "            img, [ratio_h, ratio_w] = self.resize_image_type0(img)\n",
    "        elif self.resize_type == 2:\n",
    "            img, [ratio_h, ratio_w] = self.resize_image_type2(img)\n",
    "        else:\n",
    "            img, [ratio_h, ratio_w] = self.resize_image_type1(img)\n",
    "        data['image'] = img\n",
    "        data['shape'] = np.array([src_h, src_w, ratio_h, ratio_w])\n",
    "        return data\n",
    "\n",
    "    def _pad_to_min_side(self, im, min_side, value=0):\n",
    "        \"\"\"Copy im into the top-left corner of a canvas at least min_side on each side.\"\"\"\n",
    "        h, w, c = im.shape\n",
    "        im_pad = np.zeros((max(min_side, h), max(min_side, w), c), np.uint8) + value\n",
    "        im_pad[:h, :w, :] = im\n",
    "        return im_pad\n",
    "\n",
    "    def image_padding(self, im, value=0):\n",
    "        \"\"\"Pad im with `value` so that both sides are at least 32.\"\"\"\n",
    "        return self._pad_to_min_side(im, 32, value)\n",
    "\n",
    "    def image_padding_640(self, im, value=0):\n",
    "        \"\"\"Pad im with `value` so that both sides are at least 640.\"\"\"\n",
    "        return self._pad_to_min_side(im, 640, value)\n",
    "\n",
    "    def resize_image_type1(self, img):\n",
    "        \"\"\"Resize to self.image_shape (h, w).\n",
    "\n",
    "        With keep_ratio set to True the width follows the source aspect\n",
    "        ratio, rounded up to a multiple of 32.\n",
    "\n",
    "        Returns:\n",
    "            img, [ratio_h, ratio_w]\n",
    "        \"\"\"\n",
    "        resize_h, resize_w = self.image_shape\n",
    "        ori_h, ori_w = img.shape[:2]  # (h, w, c)\n",
    "        if self.keep_ratio is True:\n",
    "            resize_w = ori_w * resize_h / ori_h\n",
    "            N = math.ceil(resize_w / 32)\n",
    "            resize_w = N * 32\n",
    "        ratio_h = float(resize_h) / ori_h\n",
    "        ratio_w = float(resize_w) / ori_w\n",
    "\n",
    "        img = cv2.resize(img, (int(resize_w), int(resize_h)))\n",
    "        return img, [ratio_h, ratio_w]\n",
    "\n",
    "    def resize_image_type0(self, img):\n",
    "        \"\"\"\n",
    "        resize image to a size multiple of 32 which is required by the network\n",
    "        args:\n",
    "            img(array): array with shape [h, w, c]\n",
    "        return(tuple):\n",
    "            img, [ratio_h, ratio_w]\n",
    "        raises:\n",
    "            Exception: if self.limit_type is not 'max', 'min' or 'resize_long'\n",
    "            RuntimeError: if cv2.resize fails (the original error is chained)\n",
    "        \"\"\"\n",
    "        limit_side_len = self.limit_side_len\n",
    "        h, w, _ = img.shape\n",
    "\n",
    "        if self.limit_type == 'max':\n",
    "            # Shrink only when the longer side exceeds the limit.\n",
    "            if max(h, w) > limit_side_len:\n",
    "                if h > w:\n",
    "                    ratio = float(limit_side_len) / h\n",
    "                else:\n",
    "                    ratio = float(limit_side_len) / w\n",
    "            else:\n",
    "                ratio = 1.\n",
    "        elif self.limit_type == 'min':\n",
    "            # Enlarge only when the shorter side is below the limit.\n",
    "            if min(h, w) < limit_side_len:\n",
    "                if h < w:\n",
    "                    ratio = float(limit_side_len) / h\n",
    "                else:\n",
    "                    ratio = float(limit_side_len) / w\n",
    "            else:\n",
    "                ratio = 1.\n",
    "        elif self.limit_type == 'resize_long':\n",
    "            ratio = float(limit_side_len) / max(h, w)\n",
    "        else:\n",
    "            raise Exception('not support limit type, image ')\n",
    "        resize_h = int(h * ratio)\n",
    "        resize_w = int(w * ratio)\n",
    "\n",
    "        # Snap to a multiple of 32 and never go below 32, so both target\n",
    "        # sides are always positive.\n",
    "        resize_h = max(int(round(resize_h / 32) * 32), 32)\n",
    "        resize_w = max(int(round(resize_w / 32) * 32), 32)\n",
    "\n",
    "        try:\n",
    "            img = cv2.resize(img, (int(resize_w), int(resize_h)))\n",
    "        except Exception as err:\n",
    "            # Surface the failure with context instead of exiting the\n",
    "            # interpreter with a success status.\n",
    "            raise RuntimeError(\n",
    "                'cv2.resize failed for input shape {} -> (w={}, h={})'.format(\n",
    "                    img.shape, resize_w, resize_h)) from err\n",
    "        ratio_h = resize_h / float(h)\n",
    "        ratio_w = resize_w / float(w)\n",
    "        return img, [ratio_h, ratio_w]\n",
    "\n",
    "    def resize_image_type2(self, img):\n",
    "        \"\"\"Scale the long side to self.resize_long, then zero-pad to >= 640x640.\n",
    "\n",
    "        Returns:\n",
    "            img, [ratio, ratio] where ratio is the single scale factor used\n",
    "            for both axes (the padding is not part of the ratio).\n",
    "        \"\"\"\n",
    "        h, w, _ = img.shape\n",
    "\n",
    "        resize_w = w\n",
    "        resize_h = h\n",
    "\n",
    "        if resize_h > resize_w:\n",
    "            ratio = float(self.resize_long) / resize_h\n",
    "        else:\n",
    "            ratio = float(self.resize_long) / resize_w\n",
    "\n",
    "        resize_h = int(resize_h * ratio)\n",
    "        resize_w = int(resize_w * ratio)\n",
    "        img = cv2.resize(img, (resize_w, resize_h))\n",
    "        # Zero-pad here so the image fits a model exported with a fixed input shape.\n",
    "        img = self.image_padding_640(img)\n",
    "        return img, [ratio, ratio]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Post-processing of the detection output (produces the text boxes)\n",
    "class DBPostProcess(object):\n",
    "    \"\"\"\n",
    "    The post process for Differentiable Binarization (DB).\n",
    "\n",
    "    Args:\n",
    "        thresh: score above which a pixel is kept in the binary map.\n",
    "        box_thresh: minimum mean score inside a box for it to be kept.\n",
    "        max_candidates: upper bound on the number of contours examined.\n",
    "        unclip_ratio: expansion factor used by unclip().\n",
    "        use_dilation: dilate the binary map with a 2x2 kernel of ones\n",
    "            before extracting contours.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 thresh=0.3,\n",
    "                 box_thresh=0.7,\n",
    "                 max_candidates=1000,\n",
    "                 unclip_ratio=2.0,\n",
    "                 use_dilation=False,\n",
    "                 **kwargs):\n",
    "        self.thresh = thresh\n",
    "        self.box_thresh = box_thresh\n",
    "        self.max_candidates = max_candidates\n",
    "        self.unclip_ratio = unclip_ratio\n",
    "        self.min_size = 3\n",
    "        self.dilation_kernel = np.array([[1, 1], [1, 1]]) if use_dilation else None\n",
    "\n",
    "    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height, ratio):\n",
    "        '''\n",
    "        Extract scored boxes from one binarized map.\n",
    "\n",
    "        pred: 2-D score map of one image.\n",
    "        _bitmap: 2-D binary map of the same image (values in {0, 1}).\n",
    "        dest_width, dest_height: size the box coordinates are mapped back to.\n",
    "        ratio: scale factor; dest size * ratio is the extent of the map that\n",
    "            is treated as image content when mapping boxes back.\n",
    "        Returns (boxes, scores): an int16 array of boxes and a list of scores.\n",
    "        '''\n",
    "        bitmap = _bitmap\n",
    "        height, width = dest_height * ratio, dest_width * ratio\n",
    "\n",
    "        outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,\n",
    "                                cv2.CHAIN_APPROX_SIMPLE)\n",
    "        # findContours returns either 3 or 2 values depending on the OpenCV\n",
    "        # version; the contours are second-to-last in both layouts.\n",
    "        contours = outs[-2]\n",
    "\n",
    "        num_contours = min(len(contours), self.max_candidates)\n",
    "\n",
    "        boxes = []\n",
    "        scores = []\n",
    "        for index in range(num_contours):\n",
    "            contour = contours[index]\n",
    "            points, sside = self.get_mini_boxes(contour)\n",
    "            if sside < self.min_size:\n",
    "                continue\n",
    "            points = np.array(points)\n",
    "            score = self.box_score_fast(pred, points.reshape(-1, 2))\n",
    "            if self.box_thresh > score:\n",
    "                continue\n",
    "\n",
    "            # Expand the box, then re-fit a rectangle around the expanded polygon.\n",
    "            box = self.unclip(points).reshape(-1, 1, 2)\n",
    "            box, sside = self.get_mini_boxes(box)\n",
    "            if sside < self.min_size + 2:\n",
    "                continue\n",
    "            box = np.array(box)\n",
    "\n",
    "            # Map from map coordinates back to the destination image size.\n",
    "            box[:, 0] = np.clip(\n",
    "                np.round(box[:, 0] / width * dest_width), 0, dest_width)\n",
    "            box[:, 1] = np.clip(\n",
    "                np.round(box[:, 1] / height * dest_height), 0, dest_height)\n",
    "            boxes.append(box.astype(np.int16))\n",
    "            scores.append(score)\n",
    "        return np.array(boxes, dtype=np.int16), scores\n",
    "\n",
    "    def unclip(self, box):\n",
    "        \"\"\"Offset the polygon outwards by area * unclip_ratio / perimeter.\"\"\"\n",
    "        unclip_ratio = self.unclip_ratio\n",
    "        poly = Polygon(box)\n",
    "        distance = poly.area * unclip_ratio / poly.length\n",
    "        offset = pyclipper.PyclipperOffset()\n",
    "        offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)\n",
    "        return np.array(offset.Execute(distance))\n",
    "\n",
    "    def get_mini_boxes(self, contour):\n",
    "        \"\"\"Fit a minimum-area rectangle around the contour.\n",
    "\n",
    "        Returns:\n",
    "            (box, short_side): the 4 corners, ordered from the two left-most\n",
    "            and two right-most points by their y values, and the shorter\n",
    "            side length of the rectangle.\n",
    "        \"\"\"\n",
    "        bounding_box = cv2.minAreaRect(contour)\n",
    "        points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])\n",
    "\n",
    "        # points[0:2] are the two left-most corners, points[2:4] the right-most.\n",
    "        if points[1][1] > points[0][1]:\n",
    "            index_1 = 0\n",
    "            index_4 = 1\n",
    "        else:\n",
    "            index_1 = 1\n",
    "            index_4 = 0\n",
    "        if points[3][1] > points[2][1]:\n",
    "            index_2 = 2\n",
    "            index_3 = 3\n",
    "        else:\n",
    "            index_2 = 3\n",
    "            index_3 = 2\n",
    "\n",
    "        box = [\n",
    "            points[index_1], points[index_2], points[index_3], points[index_4]\n",
    "        ]\n",
    "        return box, min(bounding_box[1])\n",
    "\n",
    "    def box_score_fast(self, bitmap, _box):\n",
    "        \"\"\"Mean of `bitmap` inside the polygon `_box`.\n",
    "\n",
    "        The mean is computed over the polygon's bounding window (clipped to\n",
    "        the map) using a filled polygon mask.\n",
    "        \"\"\"\n",
    "        h, w = bitmap.shape[:2]\n",
    "        box = _box.copy()\n",
    "        # The np.int alias no longer exists in current NumPy; the builtin int\n",
    "        # gives the same integer bounds.\n",
    "        xmin = int(np.clip(np.floor(box[:, 0].min()), 0, w - 1))\n",
    "        xmax = int(np.clip(np.ceil(box[:, 0].max()), 0, w - 1))\n",
    "        ymin = int(np.clip(np.floor(box[:, 1].min()), 0, h - 1))\n",
    "        ymax = int(np.clip(np.ceil(box[:, 1].max()), 0, h - 1))\n",
    "\n",
    "        mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)\n",
    "        # Shift the polygon into the window's local coordinates.\n",
    "        box[:, 0] = box[:, 0] - xmin\n",
    "        box[:, 1] = box[:, 1] - ymin\n",
    "        cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)\n",
    "        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]\n",
    "\n",
    "    def __call__(self, outs_dict, shape_list):\n",
    "        \"\"\"Turn the raw detector output into boxes for every batch item.\n",
    "\n",
    "        Args:\n",
    "            outs_dict: 4-D score array; channel 0 of each item is used.\n",
    "            shape_list: per item [src_h, src_w, ratio_h, ratio_w]; only\n",
    "                ratio_w is forwarded as the scale factor.\n",
    "\n",
    "        Returns:\n",
    "            A list with one {'points': boxes} dict per batch item.\n",
    "        \"\"\"\n",
    "        pred = outs_dict\n",
    "        pred = pred[:, 0, :, :]\n",
    "        segmentation = pred > self.thresh\n",
    "        boxes_batch = []\n",
    "        for batch_index in range(pred.shape[0]):\n",
    "            src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]\n",
    "            if self.dilation_kernel is not None:\n",
    "                mask = cv2.dilate(\n",
    "                    np.array(segmentation[batch_index]).astype(np.uint8),\n",
    "                    self.dilation_kernel)\n",
    "            else:\n",
    "                mask = segmentation[batch_index]\n",
    "            boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,\n",
    "                                                   src_w, src_h, ratio_w)\n",
    "            boxes_batch.append({'points': boxes})\n",
    "        return boxes_batch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Decode the recognition text from the raw inference output\n",
    "class process_pred(object):\n",
    "    \"\"\"Greedy text decoder: drops index 0 ('blank') and repeated indices.\"\"\"\n",
    "\n",
    "    def __init__(self, character_dict_path=None, character_type='ch', use_space_char=False):\n",
    "        self.character_str = ''\n",
    "        # Every dictionary line, with its line break removed, is appended to one string.\n",
    "        with open(character_dict_path, 'rb') as fin:\n",
    "            pieces = [raw.decode('utf-8').strip('\\n').strip('\\r\\n')\n",
    "                      for raw in fin.readlines()]\n",
    "        self.character_str = self.character_str + ''.join(pieces)\n",
    "        if use_space_char:\n",
    "            self.character_str += ' '\n",
    "\n",
    "        dict_character = self.add_special_char(list(self.character_str))\n",
    "        self.dict = dict((char, i) for i, char in enumerate(dict_character))\n",
    "        self.character = dict_character\n",
    "\n",
    "    def add_special_char(self, dict_character):\n",
    "        \"\"\"Return a new list with the 'blank' entry placed at index 0.\"\"\"\n",
    "        return ['blank'] + dict_character\n",
    "\n",
    "    def decode(self, text_index, text_prob=None, is_remove_duplicate=False):\n",
    "        \"\"\"Convert index sequences to (text, mean confidence) tuples.\n",
    "\n",
    "        Index 0 is always skipped. With is_remove_duplicate, an index equal\n",
    "        to the one right before it is skipped as well. Without text_prob\n",
    "        every kept character counts with confidence 1.\n",
    "        \"\"\"\n",
    "        ignored_tokens = [0]\n",
    "        result_list = []\n",
    "        for batch_idx, indices in enumerate(text_index):\n",
    "            chars = []\n",
    "            confs = []\n",
    "            for pos, token in enumerate(indices):\n",
    "                if token in ignored_tokens:\n",
    "                    continue\n",
    "                repeated = pos > 0 and indices[pos - 1] == token\n",
    "                if is_remove_duplicate and repeated:\n",
    "                    continue\n",
    "                chars.append(self.character[int(token)])\n",
    "                confs.append(1 if text_prob is None else text_prob[batch_idx][pos])\n",
    "            result_list.append((''.join(chars), np.mean(confs)))\n",
    "        return result_list\n",
    "\n",
    "    def __call__(self, preds, label=None):\n",
    "        \"\"\"Decode predictions (argmax/max over axis 2) and, if given, the labels.\"\"\"\n",
    "        scores = preds if isinstance(preds, np.ndarray) else np.array(preds)\n",
    "        text = self.decode(scores.argmax(axis=2), scores.max(axis=2),\n",
    "                           is_remove_duplicate=True)\n",
    "        if label is not None:\n",
    "            return text, self.decode(label)\n",
    "        return text\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "class det_rec_functions(object):\n",
    "    \"\"\"Text detection and recognition pipeline built on two ONNX models.\n",
    "\n",
    "    Args:\n",
    "        image: image array; a copy is kept as self.img.\n",
    "        use_dnn: truthy to run the models through cv2.dnn, falsy to run them\n",
    "            through onnxruntime.\n",
    "\n",
    "    Raises:\n",
    "        AttributeError: if use_dnn is truthy but cv2.dnn offers no\n",
    "            readNetFromONNX.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, image,use_dnn = False):\n",
    "        self.img = image.copy()\n",
    "        self.det_file = './det_model.onnx'\n",
    "        self.small_rec_file = './rec_model.onnx'\n",
    "        # (channels, height, padded width) of the recognition model input.\n",
    "        self.model_shape = [3,32,1000]\n",
    "        # Normalised to bool so that every method picks the same backend branch.\n",
    "        self.use_dnn = bool(use_dnn)\n",
    "        if not self.use_dnn:\n",
    "            self.onet_det_session = onnxruntime.InferenceSession(self.det_file)\n",
    "            self.onet_rec_session = onnxruntime.InferenceSession(self.small_rec_file)\n",
    "        else:\n",
    "            if not hasattr(cv2.dnn, 'readNetFromONNX'):\n",
    "                # Same exception type as the bare attribute lookup, but actionable.\n",
    "                raise AttributeError(\n",
    "                    \"module 'cv2.dnn' has no attribute 'readNetFromONNX'; \"\n",
    "                    'this OpenCV build cannot load ONNX models, upgrade '\n",
    "                    'opencv-python or pass use_dnn=False')\n",
    "            self.onet_det_session = cv2.dnn.readNetFromONNX(self.det_file)\n",
    "            self.onet_rec_session = cv2.dnn.readNetFromONNX(self.small_rec_file)\n",
    "        self.infer_before_process_op, self.det_re_process_op = self.get_process()\n",
    "        self.postprocess_op = process_pred('./ppocr_keys_v1.txt', 'ch', True)\n",
    "\n",
    "    ## Image pre-processing\n",
    "    def transform(self, data, ops=None):\n",
    "        \"\"\"Apply the operators in order; stop and return None if one returns None.\"\"\"\n",
    "        if ops is None:\n",
    "            ops = []\n",
    "        for op in ops:\n",
    "            data = op(data)\n",
    "            if data is None:\n",
    "                return None\n",
    "        return data\n",
    "\n",
    "    def create_operators(self, op_param_list, global_config=None):\n",
    "        \"\"\"\n",
    "        create operators based on the config\n",
    "\n",
    "        Args:\n",
    "            op_param_list(list): a list of single-key dicts mapping an\n",
    "                operator name to its keyword arguments (or None)\n",
    "            global_config(dict): optional keyword arguments merged into the\n",
    "                arguments of every operator\n",
    "\n",
    "        Returns:\n",
    "            list of instantiated operators, in config order\n",
    "        \"\"\"\n",
    "        assert isinstance(op_param_list, list), ('operator config should be a list')\n",
    "        ops = []\n",
    "        for operator in op_param_list:\n",
    "            assert isinstance(operator,\n",
    "                              dict) and len(operator) == 1, \"yaml format error\"\n",
    "            op_name = list(operator)[0]\n",
    "            # Work on a copy so the caller's config is never modified.\n",
    "            param = {} if operator[op_name] is None else dict(operator[op_name])\n",
    "            if global_config is not None:\n",
    "                param.update(global_config)\n",
    "            # NOTE(review): eval() resolves the operator name; only use\n",
    "            # trusted configuration here.\n",
    "            op = eval(op_name)(**param)\n",
    "            ops.append(op)\n",
    "        return ops\n",
    "\n",
    "    ### Post-processing of the detected boxes\n",
    "    def order_points_clockwise(self, pts):\n",
    "        \"\"\"\n",
    "        reference from: https://github.com/jrosebr1/imutils/blob/master/imutils/perspective.py\n",
    "\n",
    "        Order 4 points as top-left, top-right, bottom-right, bottom-left and\n",
    "        return them as a float32 array.\n",
    "        \"\"\"\n",
    "        # sort the points based on their x-coordinates\n",
    "        xSorted = pts[np.argsort(pts[:, 0]), :]\n",
    "\n",
    "        # grab the left-most and right-most points from the sorted\n",
    "        # x-coordinate points\n",
    "        leftMost = xSorted[:2, :]\n",
    "        rightMost = xSorted[2:, :]\n",
    "\n",
    "        # now, sort the left-most coordinates according to their\n",
    "        # y-coordinates so we can grab the top-left and bottom-left\n",
    "        # points, respectively\n",
    "        leftMost = leftMost[np.argsort(leftMost[:, 1]), :]\n",
    "        (tl, bl) = leftMost\n",
    "\n",
    "        rightMost = rightMost[np.argsort(rightMost[:, 1]), :]\n",
    "        (tr, br) = rightMost\n",
    "\n",
    "        rect = np.array([tl, tr, br, bl], dtype='float32')\n",
    "        return rect\n",
    "\n",
    "    def clip_det_res(self, points, img_height, img_width):\n",
    "        \"\"\"Clamp every point into the image, modifying `points` in place.\"\"\"\n",
    "        for pno in range(points.shape[0]):\n",
    "            points[pno, 0] = int(min(max(points[pno, 0], 0), img_width - 1))\n",
    "            points[pno, 1] = int(min(max(points[pno, 1], 0), img_height - 1))\n",
    "        return points\n",
    "\n",
    "    # example: shape_part_list = [661 969 7.74583964e-01 6.60474716e-01]\n",
    "    def filter_tag_det_res(self, dt_boxes, shape_part_list):\n",
    "        \"\"\"Order, clip and drop boxes whose width or height is 3 pixels or less.\n",
    "\n",
    "        shape_part_list[0] and shape_part_list[1] are used as image height\n",
    "        and width.\n",
    "        \"\"\"\n",
    "        img_height, img_width = shape_part_list[0],shape_part_list[1]\n",
    "        dt_boxes_new = []\n",
    "        for box in dt_boxes:\n",
    "            box = self.order_points_clockwise(box)\n",
    "            box = self.clip_det_res(box, img_height, img_width)\n",
    "            rect_width = int(np.linalg.norm(box[0] - box[1]))\n",
    "            rect_height = int(np.linalg.norm(box[0] - box[3]))\n",
    "            if rect_width <= 3 or rect_height <= 3:\n",
    "                continue\n",
    "            dt_boxes_new.append(box)\n",
    "        dt_boxes = np.array(dt_boxes_new)\n",
    "        return dt_boxes\n",
    "\n",
    "    ### Build the image pre-processing and the detection post-processing\n",
    "    def get_process(self):\n",
    "        \"\"\"Return (pre-processing operator list, DBPostProcess instance).\"\"\"\n",
    "        det_db_thresh = 0.3\n",
    "        det_db_box_thresh = 0.3\n",
    "        max_candidates = 2000\n",
    "        unclip_ratio = 1.6\n",
    "        use_dilation = True\n",
    "        # DetResizeForTest defines how the image is resized for the detection model.\n",
    "        pre_process_list = [{\n",
    "            'DetResizeForTest': {\n",
    "                # 'limit_side_len': 2500,\n",
    "                # 'limit_type': 'max',\n",
    "                'resize_long': 640\n",
    "                # 'image_shape':[640,640],\n",
    "                # 'keep_ratio':True,\n",
    "            }\n",
    "        }, {\n",
    "            'NormalizeImage': {\n",
    "                'std': [0.229, 0.224, 0.225],\n",
    "                'mean': [0.485, 0.456, 0.406],\n",
    "                'scale': '1./255.',\n",
    "                'order': 'hwc'\n",
    "            }\n",
    "        }, {\n",
    "            'ToCHWImage': None\n",
    "        }, {\n",
    "            'KeepKeys': {\n",
    "                'keep_keys': ['image', 'shape']\n",
    "            }\n",
    "        }]\n",
    "\n",
    "        infer_before_process_op = self.create_operators(pre_process_list)\n",
    "        det_re_process_op = DBPostProcess(det_db_thresh, det_db_box_thresh, max_candidates, unclip_ratio, use_dilation)\n",
    "        return infer_before_process_op, det_re_process_op\n",
    "\n",
    "    def sorted_boxes(self, dt_boxes):\n",
    "        \"\"\"\n",
    "        Sort text boxes in order from top to bottom, left to right\n",
    "        args:\n",
    "            dt_boxes(array): detected text boxes, one [4, 2] box per entry\n",
    "        return:\n",
    "            list of the boxes in sorted order\n",
    "        \"\"\"\n",
    "        num_boxes = dt_boxes.shape[0]\n",
    "        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))\n",
    "        _boxes = list(sorted_boxes)\n",
    "\n",
    "        # One pass: neighbours whose first points differ by less than 10 in y\n",
    "        # are treated as the same line and ordered by x.\n",
    "        for i in range(num_boxes - 1):\n",
    "            if (abs(_boxes[i + 1][0][1] - _boxes[i][0][1]) < 10 and\n",
    "                    _boxes[i + 1][0][0] < _boxes[i][0][0]):\n",
    "                _boxes[i], _boxes[i + 1] = _boxes[i + 1], _boxes[i]\n",
    "        return _boxes\n",
    "\n",
    "    ### Pre-processing of a crop for the recognition model\n",
    "    def resize_norm_img(self, img):\n",
    "        \"\"\"Resize a crop to the model height, scale it to [-1, 1] and right-pad with zeros.\n",
    "\n",
    "        Returns a float32 array shaped like self.model_shape (C, H, W).\n",
    "        \"\"\"\n",
    "        imgC, imgH, imgW = [int(v) for v in self.model_shape]\n",
    "        assert imgC == img.shape[2]\n",
    "        h, w = img.shape[:2]\n",
    "        ratio = w / float(h)\n",
    "        # Keep the aspect ratio but never exceed the model width.\n",
    "        if math.ceil(imgH * ratio) > imgW:\n",
    "            resized_w = imgW\n",
    "        else:\n",
    "            resized_w = int(math.ceil(imgH * ratio))\n",
    "        resized_image = cv2.resize(img, (resized_w, imgH))\n",
    "        resized_image = resized_image.astype('float32')\n",
    "        resized_image = resized_image.transpose((2, 0, 1)) / 255\n",
    "        resized_image -= 0.5\n",
    "        resized_image /= 0.5\n",
    "        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)\n",
    "        padding_im[:, :, 0:resized_w] = resized_image\n",
    "        return padding_im\n",
    "\n",
    "    ## Run text detection on the stored image\n",
    "    def get_boxes(self):\n",
    "        \"\"\"Detect text boxes in self.img.\n",
    "\n",
    "        Returns:\n",
    "            (boxes, batch): the sorted list of boxes and the pre-processed\n",
    "            image batch that was fed to the detector.\n",
    "        \"\"\"\n",
    "        img_ori = self.img\n",
    "        img_part = img_ori.copy()\n",
    "        data_part = {'image': img_part}\n",
    "        data_part = self.transform(data_part, self.infer_before_process_op)\n",
    "        img_part, shape_part_list = data_part\n",
    "        img_part = np.expand_dims(img_part, axis=0)\n",
    "        shape_part_list = np.expand_dims(shape_part_list, axis=0)\n",
    "        if self.use_dnn:\n",
    "            self.onet_det_session.setInput(img_part)\n",
    "            outs_part = self.onet_det_session.forward()\n",
    "        else:\n",
    "            inputs_part = {self.onet_det_session.get_inputs()[0].name: img_part}\n",
    "            outs_part = self.onet_det_session.run(None, inputs_part)\n",
    "            outs_part = outs_part[0]\n",
    "        post_res_part = self.det_re_process_op(outs_part, shape_part_list)\n",
    "        dt_boxes_part = post_res_part[0]['points']\n",
    "        dt_boxes_part = self.filter_tag_det_res(dt_boxes_part,shape_part_list[0])\n",
    "        dt_boxes_part = self.sorted_boxes(dt_boxes_part)\n",
    "\n",
    "        return dt_boxes_part,img_part\n",
    "\n",
    "    ### Cut the region of one bounding box out of the image\n",
    "    def get_rotate_crop_image(self, img, points):\n",
    "        \"\"\"Warp the quadrilateral `points` to an upright rectangle.\n",
    "\n",
    "        Crops whose height is at least 1.5 times their width are rotated\n",
    "        with np.rot90.\n",
    "        \"\"\"\n",
    "        img_crop_width = int(\n",
    "            max(\n",
    "                np.linalg.norm(points[0] - points[1]),\n",
    "                np.linalg.norm(points[2] - points[3])))\n",
    "        img_crop_height = int(\n",
    "            max(\n",
    "                np.linalg.norm(points[0] - points[3]),\n",
    "                np.linalg.norm(points[1] - points[2])))\n",
    "        pts_std = np.float32([[0, 0], [img_crop_width, 0],\n",
    "                              [img_crop_width, img_crop_height],\n",
    "                              [0, img_crop_height]])\n",
    "        M = cv2.getPerspectiveTransform(points, pts_std)\n",
    "        dst_img = cv2.warpPerspective(\n",
    "            img,\n",
    "            M, (img_crop_width, img_crop_height),\n",
    "            borderMode=cv2.BORDER_REPLICATE,\n",
    "            flags=cv2.INTER_CUBIC)\n",
    "        dst_img_height, dst_img_width = dst_img.shape[0:2]\n",
    "        if dst_img_height * 1.0 / dst_img_width >= 1.5:\n",
    "            dst_img = np.rot90(dst_img)\n",
    "        return dst_img\n",
    "\n",
    "    ### Inference on a single crop\n",
    "    def get_img_res(self, onnx_model, img, process_op):\n",
    "        \"\"\"Pre-process one crop, run `onnx_model` on it and decode with `process_op`.\"\"\"\n",
    "        img = self.resize_norm_img(img)\n",
    "        img = img[np.newaxis, :]\n",
    "        if self.use_dnn:\n",
    "            onnx_model.setInput(img)  # set the model input\n",
    "            outs = onnx_model.forward()  # run inference\n",
    "        else:\n",
    "            inputs = {onnx_model.get_inputs()[0].name: img}\n",
    "            outs = onnx_model.run(None, inputs)\n",
    "            outs = outs[0]\n",
    "        return process_op(outs)\n",
    "\n",
    "    def recognition_img(self, dt_boxes):\n",
    "        \"\"\"Recognise the text inside every detected box.\n",
    "\n",
    "        Args:\n",
    "            dt_boxes: the value returned by get_boxes(); only its first\n",
    "                element (the box list) is used.\n",
    "\n",
    "        Returns:\n",
    "            (results, results_info): per box the first decoded entry, and\n",
    "            the full decoder output.\n",
    "        \"\"\"\n",
    "        img_ori = self.img  # original-size image\n",
    "        img = img_ori.copy()\n",
    "        img_list = []\n",
    "        for box in dt_boxes[0]:\n",
    "            tmp_box = copy.deepcopy(box)\n",
    "            img_crop = self.get_rotate_crop_image(img, tmp_box)\n",
    "            img_list.append(img_crop)\n",
    "\n",
    "        ## Recognise every crop\n",
    "        results = []\n",
    "        results_info = []\n",
    "        for pic in img_list:\n",
    "            res = self.get_img_res(self.onet_rec_session, pic, self.postprocess_op)\n",
    "            results.append(res[0])\n",
    "            results_info.append(res)\n",
    "        return results, results_info\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "ename": "AttributeError",
     "evalue": "module 'cv2.dnn' has no attribute 'readNetFromONNX'",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_1964\\2969005901.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      4\u001b[0m     \u001b[1;31m# 文本检测\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m     \u001b[1;31m# 模型固化为640*640 需要修改对应前处理，box的后处理。\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m     \u001b[0mocr_sys\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mdet_rec_functions\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mimage\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0muse_dnn\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      7\u001b[0m     \u001b[1;31m# 得到检测框\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      8\u001b[0m     \u001b[0mdt_boxes\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mocr_sys\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_boxes\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_1964\\804329560.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, image, use_dnn)\u001b[0m\n\u001b[0;32m     10\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0monet_rec_session\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0monnxruntime\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mInferenceSession\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msmall_rec_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     11\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 12\u001b[1;33m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0monet_det_session\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreadNetFromONNX\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdet_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     13\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0monet_rec_session\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mreadNetFromONNX\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0msmall_rec_file\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     14\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0minfer_before_process_op\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdet_re_process_op\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_process\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mAttributeError\u001b[0m: module 'cv2.dnn' has no attribute 'readNetFromONNX'"
     ]
    }
   ],
   "source": [
    "if __name__=='__main__':\n",
    "    # Read the input image.\n",
    "    image_path = './00.png'\n",
    "    image = cv2.imread(image_path)\n",
    "    if image is None:\n",
    "        # cv2.imread returns None instead of raising when the file cannot be read.\n",
    "        raise FileNotFoundError(f'cannot read image: {image_path}')\n",
    "    # The detection model is frozen to a 640x640 input, so the pre-processing\n",
    "    # and the box post-processing are adapted to that shape.\n",
    "    # Per the original notes: the fixed-shape recognition model only accepts a\n",
    "    # limited width (1000 here; export a model for your own scenario if\n",
    "    # needed), while onnxruntime can run dynamic shapes without that limit.\n",
    "    backends = (('opencv dnn', True), ('onnxruntime', False))\n",
    "    for run_index, (label, use_dnn) in enumerate(backends):\n",
    "        if run_index:\n",
    "            print('------------------------------')\n",
    "        ocr_sys = det_rec_functions(image, use_dnn=use_dnn)\n",
    "        # Detect the text boxes.\n",
    "        dt_boxes = ocr_sys.get_boxes()\n",
    "        # results: first decoded entry per box, results_info: full decoder output.\n",
    "        results, results_info = ocr_sys.recognition_img(dt_boxes)\n",
    "        print(f'{label} :{str(results)}')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6.13 ('paddle')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "8400e034ae209d81f9114f403eb7f9856a4c9161d3b220bef5747c4b976a2b28"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
